"""
Major library for graph dictionary learning algorithms

Created on: July 25, 2022

"""
# pylint: disable=anomalous-backslash-in-string
# pylint: disable=invalid-name
# pylint: disable=missing-function-docstring

from time import time
import sys
import os
import pickle
from collections import defaultdict
from multiprocessing import Process, Manager
from ctypes import c_char_p
import heapq
from operator import itemgetter

import numpy as np
from sklearn import linear_model
from sklearn.utils.extmath import randomized_svd
import scipy
import torch

# Module-wide debug switch: when True, SparseDictionaryLearning.fit prints a
# per-iteration timing breakdown (phase a, phase O, state update).
DEBUG = False

def torch_batch_matrix_mul(a, b, batch_size=1024, device="cpu"):
    """
    Compute the matrix product ``a @ b`` with torch, streaming ``a`` in row batches.

    Both operands are cast to float32 before the multiplication, so the result
    is a float32 numpy array regardless of the input dtypes.

    Parameters:
    -----------
        a: 2-D numpy array whose rows are split into batches (left operand)
        b: numpy array used in full for every batch (right operand)
        batch_size: number of rows of a per batch, -1 to use all rows at once
        device: torch device on which the products are evaluated

    Returns:
    -----------
        numpy array holding the per-batch products concatenated along axis 0

    Raises:
    -----------
        ValueError: if batch_size is neither -1 nor a positive number
    """
    num_rows = a.shape[0]
    if num_rows == 0:
        # Nothing to multiply: the batch loop would yield no pieces, so
        # np.concatenate would fail on an empty list (and batch_size=-1 would
        # divide by zero). Return the empty product directly.
        return np.zeros((0,) + tuple(b.shape[1:]), dtype=np.float32)
    if batch_size == -1:
        batch_size = num_rows
    if batch_size <= 0:
        raise ValueError(f"batch_size should be -1 or a positive integer, got {batch_size}.")
    with torch.no_grad():
        # b is moved to the device once and reused by every batch
        torch_b = torch.from_numpy(b).float().to(device)
        list_res = []
        for start in range(0, num_rows, batch_size):
            torch_cur_a = torch.from_numpy(a[start:start + batch_size, :]).float().to(device)
            list_res.append((torch_cur_a @ torch_b).cpu().numpy())
    return np.concatenate(list_res)

def np_wthresh(A: np.ndarray, lam: float) -> np.ndarray:
    """
    Element-wise soft thresholding.

    Entries whose magnitude is at most ``lam`` become 0; every other entry is
    moved towards zero by ``lam`` while keeping its sign. The input is not
    modified.

    Parameters:
    -----------
        A: array (or array-like) to shrink
        lam: threshold

    Returns:
    -----------
        a new array with the shape of A. Floating/complex dtypes are kept;
        integer and boolean inputs are returned as float64 so the shrunk values
        are not truncated.
    """
    A = np.asarray(A)
    if np.issubdtype(A.dtype, np.inexact):
        B = A.copy()
    else:
        # writing (|A| - lam) * sign back into an integer array would silently
        # drop the fractional part, so promote first
        B = A.astype(np.float64)

    magnitude = np.abs(A)
    shrink = magnitude > lam

    B[magnitude <= lam] = 0
    B[shrink] = (magnitude[shrink] - lam) * np.sign(A)[shrink]

    return B

class SparseDictionaryLearning: # pylint: disable=too-many-instance-attributes
    r"""
    Solve the sparse dictionary learning problem, optionally reweighted by a given matrix.

    Given signals A={A1^T; A2^T;...; An^T} (n by F), whose row vectors Ai^T are the signals, we
    wish to learn a dictionary O (c by F), whose rows are the base vectors, along with
    coefficients a={a1^T; a2^T; ...; an^T} (n by c), so that each signal is approximated by
    Ai^T = ai^T O. To find O and a, we solve the optimization problem (Original)
        minimize_{O, a} \sum_z 1/2|a_z^T O - Az^T|_2^2+\lambda |a_z|_1,
        s.t. O_{i*}O_{i*}^T \le 1, \forall i.
    Another variant reweights the objective with a given matrix F (the ``F`` constructor
    argument), and we wish to solve (Weighted)
        minimize_{O, a} \sum_z 1/2|a_z^T O F - Az^T F|_2^2+\lambda |a_z|_1,
        s.t. O_{i*}O_{i*}^T \le 1, \forall i.
    This class solves the above two problems.

    Variables:
    -----------
        O: the dictionary
        a: the coefficients
        log: training log for various statistics

    Methods:
    -----------
        fit: run the alternating optimization and return self
        save / load: pickle O, a and the log to / from a folder

    """
    Q_INIT_DIAG_VAL = 1.0e-6 # small value to initialize the diagonal of Q
    O_NONACTIVE_THRESHOLD = 1.0e-3 # threshold to detect a nonactive row of O
    # passed as the scale (standard deviation) of the Gaussian noise added to
    # the signal rows that are copied into O at initialization and resampling
    NORMAL_VAR = 0.01
    def __init__(self, # pylint: disable=too-many-arguments, too-many-statements
                 A,
                 dict_size,
                 F=None,
                 weighted=True,
                 epoch=5,
                 batch_size=64,
                 eval_step=20,
                 eval_batch_size=-1,
                 a_method='lasso_lars',
                 lam=1.0e-6,
                 n_a_nonzero=20,
                 shuffle=True,
                 num_worker=1,
                 O_Q_ST_accurate=True,
                 O_loop_cnt=1,
                 O_init_method='random_select',
                 O_resample_method='no',
                 O_resample_warmup=5,
                 O_resample_step=20,
                 device="cpu",
                 verbose=True,
                 **kwargs):
        """
        Parameters:
        -----------
            A: (normalized) adjacent matrix representing the graph
            dict_size: how many base to use in the dictionary
            F: Reweiting matrix for the weighted version
            weighted: swtich the mode between the Original solver or the weighted solver
            epoch: maximum number of training epoch
            batch_size: batch size in each training iteration
            eval_step: how many iterations to run before an evaluation step
            eval_batch_size: evaluation batch size, -1 for full batch
            a_method: optimizer to solve the a in phase a. lars, lasso_lars
            lam: lasso regularizer for solving a when the a_method is lasso_lars
            n_a_nonzero: the number of nonzeros for solving a when the a_method is lars
            shuffle: whether shuffle the sampling order in each epoch
            num_worker: number of works to compute phase a
            O_Q_ST_accurate: use SGD cache or accurate Q ST for optimizing O
            O_loop_cnt: iteration in optimizing the O phase
            O_init_method: initialization method for O
            O_resample_method: approach to resample unused rows of O. 'no' for no resampling,
                               'uniform' for uniform resampling, 'greedy' for greedily resample
                               the rows that a has the largest loss.
            O_resample_warmup: warm up iterations before execute the resampling of O
            O_resample_step: how many iterations to execte before the resampling O
            verbose: whether to display intermeidate results
        """
        self.A = A
        self.dict_size = dict_size
        self.F = F
        self.weighted = weighted
        self.epoch = epoch
        self.batch_size = batch_size
        self.eval_step = eval_step
        self.eval_batch_size = eval_batch_size
        self.a_method = a_method
        self.lam = lam
        self.n_a_nonzero = n_a_nonzero
        self.shuffle = shuffle
        self.num_worker = num_worker
        self.O_Q_ST_accurate = O_Q_ST_accurate
        self.O_loop_cnt = O_loop_cnt
        self.O_resample_method = O_resample_method
        self.O_resample_warmup = O_resample_warmup
        self.O_resample_step = O_resample_step
        self.O_init_method = O_init_method
        self.verbose = verbose
        self.device = device
        self.kwargs = kwargs

        if weighted:
            if self.verbose:
                print("Weighted version.")
            if self.F is None:
                raise ValueError("F should not be none under weighted mode! ")
            if "AF" in self.kwargs:
                self.AF = self.kwargs["AF"]
            else:
                self.AF = A.dot(F)
            self.phase_a = self.phase_a_weighted
        else:
            if self.verbose:
                print("Original version.")
            self.phase_a = self.phase_a_original

        self.data_size, self.feature_dim = self.A.shape # feature dimension and the data size
        self.global_iter_cnt = None
        self.log = defaultdict(list)
        self._iter_cnt, self._sample_idx = None, None
        self._max_iter = int(np.ceil(self.data_size / self.batch_size))

        self.lam /= self.feature_dim # renormalization

        self.O, self.a, self.a_old = None, None, None # dictionary, coefficients, coefficients cache
        self.Q, self.ST = None, None
        self.per_sample_regret = None # per sample regret to pick resampled O

        self._check_param()

    def _check_param(self):
        """Validate the configuration: the dictionary may not exceed the feature dimension."""
        if not self.dict_size > self.feature_dim:
            return
        raise ValueError("Does not support the case "
                         "that dict size is larger than feature dimension! "
                         f"Current dict size {self.dict_size} and "
                         f"current feature dimension {self.feature_dim}")

    def _sampler(self):
        """handle the random mini-batch indices"""
        # initialize the parameters
        self._sample_idx = np.arange(self.data_size)
        if self.shuffle:
            np.random.shuffle(self._sample_idx)
        # main loop
        for self._iter_cnt in range(self._max_iter):
            yield self._sample_idx[self._iter_cnt*self.batch_size:
                                   min((self._iter_cnt+1)*self.batch_size, self.data_size)]

    @staticmethod
    def eval_metrics_torch(O, a, A, F, eval_batch_size=-1, AF=None, device="cpu", verbose=True): # pylint: disable=too-many-locals
        """
        Evaluating the performance metrics
            All written together to boost the efficiency and reduce redundant computation
        """
        if verbose:
            print("---- Evaluating...")

        if F is not None and AF is None:
            raise ValueError("AF should be provided if F is not none. ")

        _, feature_dim = A.shape
        regret, rel_regret, weighted_regret, rel_weighted_regret = np.NAN, np.NAN, np.NAN, np.NAN

        list_DRSS, list_SNSS, list_WDRSS, list_WSNSS = [], [], [], []
        if eval_batch_size == -1:
            eval_batch_size = feature_dim
        max_batch = int(np.ceil(feature_dim/eval_batch_size))

        with torch.no_grad():
            tch_O = torch.from_numpy(O).to(device)
            tch_F = torch.from_numpy(F).to(device)

            t_start = time()
            for bb in range(max_batch):
                cur_a = torch.from_numpy(a[bb*eval_batch_size:min((bb+1)*eval_batch_size,
                                           feature_dim), :]).float().to(device)
                cur_A = torch.from_numpy(A[bb*eval_batch_size:min((bb+1)*eval_batch_size,
                                         feature_dim)].toarray()).float().to(device)
                a_dot_O = torch.mm(cur_a, tch_O)
                # regret
                diff_row_square_sum = torch.norm(a_dot_O - cur_A, dim=1) ** 2
                self_norm_square_sum = torch.norm(cur_A, dim=1) ** 2
                list_DRSS.append(diff_row_square_sum.cpu().detach().numpy())
                list_SNSS.append(self_norm_square_sum.cpu().detach().numpy())
                # weighted regret
                if F is not None:
                    cur_AF =\
                        torch.from_numpy(
                            AF[bb*eval_batch_size:min((bb+1)*eval_batch_size,
                                    feature_dim), :]).float().to(device)
                    a_dot_O_dot_F = torch.mm(a_dot_O, tch_F)
                    weighted_diff_row_square_sum = torch.norm(a_dot_O_dot_F - cur_AF, dim=1) ** 2
                    weighted_self_norm_square_sum = torch.norm(cur_AF, dim=1) ** 2
                    list_WDRSS.append(weighted_diff_row_square_sum.cpu().detach().numpy())
                    list_WSNSS.append(weighted_self_norm_square_sum.cpu().detach().numpy())
                if verbose:
                    cur_time = time() - t_start
                    ETA = cur_time / (bb+1) * (max_batch - bb - 1)
                    sys.stdout.write(f"{bb+1}/{max_batch} finished in {cur_time:.1f}s, "
                                     f"ETA {ETA:.1f}\r")
                    sys.stdout.flush()
            if verbose:
                print()
            DRSS = np.concatenate(list_DRSS)
            SNSS = np.concatenate(list_SNSS)
            regret = np.sum(DRSS)
            per_sample_rel_regret = DRSS / SNSS
            rel_regret = np.mean(per_sample_rel_regret)
            # weighted regret
            if F is not None:
                WDRSS = np.concatenate(list_WDRSS)
                WSNSS = np.concatenate(list_WSNSS)
                weighted_regret = np.sum(WDRSS)
                per_sample_rel_weighted_regret = WDRSS / WSNSS
                rel_weighted_regret = np.mean(per_sample_rel_weighted_regret)
                cached_per_sample_regret = per_sample_rel_weighted_regret
            else:
                cached_per_sample_regret = per_sample_rel_regret
        return regret, rel_regret, weighted_regret, rel_weighted_regret, cached_per_sample_regret

    @staticmethod
    def _init_O_a(dict_size, target, method="random_svd"):
        data_size = target.shape[0]
        if "random_select" == method:
            seed = np.random.choice(list(range(data_size)), dict_size, replace=False)
            O = target[seed]
            if not isinstance(target, np.ndarray):
                O = O.toarray()
            O += np.random.normal(0, SparseDictionaryLearning.NORMAL_VAR, O.shape)
            a = np.zeros([data_size, dict_size])
        elif "random_svd" == method:
            a, S, dictionary = randomized_svd(target, dict_size, random_state=0)
            O = S[:, np.newaxis] * dictionary
        else:
            raise ValueError("O initialization method could be random_select, random_svd."
                             f"Undefined method for initilize O: {method}.")
        return O, a

    @staticmethod
    def _init_QS(dict_size, feature_dim):
        Q = SparseDictionaryLearning.Q_INIT_DIAG_VAL * np.identity(dict_size)
        S = np.zeros([dict_size, feature_dim])
        return Q, S

    @staticmethod
    def _compute_a(v, O, gram, Xy, method='lasso_lars', lam=1.0e-3, n_a_nonzero=20):
        """
        Keep this structure for multi-processing extension
        """
        if method == "lasso_lars":
            reg = linear_model.LassoLars(alpha=lam, precompute=gram,
                                         normalize=False, fit_intercept=False)
            reg.fit(X=O, y=v, Xy=Xy)
            res = reg.coef_
        elif method == "max_a_nonzero":
            if isinstance(n_a_nonzero, list):
                max_iter = n_a_nonzero[-1]
            else:
                max_iter = n_a_nonzero
            reg = linear_model.LassoLars(alpha=1e-7, precompute=gram, max_iter=max_iter,
                                         normalize=False, fit_intercept=False)
            reg.fit(X=O, y=v, Xy=Xy)
            res = reg.coef_
        else:
            raise ValueError("a phase method could only be lasso_lars, lars."
                             f"Unrecognized a method {method}.")
        return res

    @staticmethod
    def _thread_wrapper(p_str_res, ii, v, O, gram, Xy, method, lam, n_a_nonzero): # pylint: disable=too-many-arguments
        res = SparseDictionaryLearning._compute_a(v,
                                                  O,
                                                  gram,
                                                  Xy,
                                                  method=method,
                                                  lam=lam,
                                                  n_a_nonzero=n_a_nonzero)
        p_str_res.value = pickle.dumps([ii, res], protocol=0)


    def compute_a(self, v, O, method='lasso_lars', lam=1.0e-3, n_a_nonzero=20,
                  num_worker=1):
        """
        Solve the sparse coefficients of every row of v with respect to the dictionary O.

        Parameters:
        -----------
            v: signals to encode, one per row
            O: dictionary, one base per row
            method, lam, n_a_nonzero: forwarded to _compute_a
            num_worker: 1 to solve in-process, otherwise the number of worker processes
                        across which the signals are split

        Returns:
        -----------
            2-D array of coefficients with one row per row of v

        Raises:
        -----------
            RuntimeError: if a worker process exits abnormally
        """
        num_a, _ = v.shape
        # precomputation to speed up lars solver: the solver sees X = O^T and y = v^T
        X = O.transpose()
        y = v.transpose()
        gram = torch_batch_matrix_mul(O, X, batch_size=self.eval_batch_size,
                                      device=self.device)
        Xy = np.dot(O, y)

        if num_worker == 1:
            res = np.array(SparseDictionaryLearning._compute_a(y,
                                                               X,
                                                               gram,
                                                               Xy,
                                                               method=method,
                                                               lam=lam,
                                                               n_a_nonzero=n_a_nonzero))
        else:
            block_size = max(1, int(np.ceil(num_a / num_worker)))
            # only non-empty blocks get a worker (num_a may be smaller than num_worker)
            bounds = [(lo, min(num_a, lo + block_size)) for lo in range(0, num_a, block_size)]
            # the context manager shuts the manager process down once results are read
            with Manager() as manager:
                res_cache, my_procs = [], []
                for ii, (lo, hi) in enumerate(bounds):
                    res_cache.append(manager.Value(c_char_p, ""))
                    # each worker must get the Xy columns that match its slice of y;
                    # the solver pairs column k of Xy with column k of y
                    my_procs.append(Process(target=SparseDictionaryLearning._thread_wrapper,
                                            args=(res_cache[-1],
                                                  ii,
                                                  y[:, lo:hi],
                                                  X,
                                                  gram,
                                                  Xy[:, lo:hi],
                                                  method,
                                                  lam,
                                                  n_a_nonzero)))
                for proc in my_procs:
                    proc.start()
                for proc in my_procs:
                    proc.join()
                for ii, proc in enumerate(my_procs):
                    if proc.exitcode != 0:
                        raise RuntimeError(f"Worker {ii} of phase a exited with "
                                           f"code {proc.exitcode}.")
                blocks = [pickle.loads(cache.value) for cache in res_cache]
            blocks.sort(key=itemgetter(0))
            # a single-signal block comes back 1-D, so lift every block to 2-D first
            res = np.concatenate([np.atleast_2d(coef) for _, coef in blocks], axis=0)
            # end of multi worker processing

        # a single-signal batch yields a 1-D result; callers rely on (num_a, dict_size)
        return np.atleast_2d(res)

    @staticmethod
    def update_O_torch(O, Q, ST, O_loop_cnt=1, device="cpu"):
        dict_size, _ = O.shape

        with torch.no_grad():
            torch_O = torch.from_numpy(O).float().to(device)
            torch_Q = torch.from_numpy(Q).float().to(device)
            torch_ST = torch.from_numpy(ST).float().to(device)
            for _ in range(O_loop_cnt):
                for j in range(dict_size):
                    tmp = (torch_ST[j, :] - torch_Q[j, :] @ torch_O) / torch_Q[j, j] + torch_O[j, :]
                    torch_O[j, :] = tmp / torch.clamp(torch.norm(tmp), min=1)
        O[:] = torch_O.cpu().detach().numpy()

    def eval_and_log(self):
        # this_regret, this_rel_regret, this_weighted_regret, this_rel_weighted_regret,\
        #     self.per_sample_regret = self.eval_metrics(self.O, self.a, self.A, self.F,
        #                                                eval_batch_size=self.eval_batch_size)
        this_regret, this_rel_regret, this_weighted_regret, this_rel_weighted_regret,\
            self.per_sample_regret = self.eval_metrics_torch(self.O, self.a, self.A, self.F,
                                                             eval_batch_size=self.eval_batch_size,
                                                             AF = self.AF, device=self.device,
                                                             verbose=self.verbose)
        self.log["regret"].append(this_regret)
        self.log["rel_regret"].append(this_rel_regret)
        self.log["weighted_regret"].append(this_weighted_regret)
        self.log["rel_weighted_regret"].append(this_rel_weighted_regret)

        self.log["a_row_nonzero"].append(np.sum(self.a!=0, axis=1))
        self.log["O_row_nonzero"].append(np.sum(self.O!=0, axis=1))

    def _display_stat(self):
        print(f"Weighted regret: {self.log['weighted_regret'][-1]:.6f},"
              f" | Rel weighted regret: {self.log['rel_weighted_regret'][-1]:.6f}"
              f" | Regret: {self.log['regret'][-1]:.6f}"
              f" | Rel Regret: {self.log['rel_regret'][-1]:.6f}")
        print(f"  a row nonzeros: {np.mean(self.log['a_row_nonzero'][-1]):.4f} " + u"\u00B1"
              f" {np.std(self.log['a_row_nonzero'][-1]):.4f}"
              f" | O row nonzeros: mean {np.mean(self.log['O_row_nonzero'][-1]):.4f} " + u"\u00B1"
              f" {np.std(self.log['O_row_nonzero'][-1]):.4f}")

    def fit(self):
        """
        Pure fit without O sparsification.

        Initializes O, a and the Q / ST statistics, evaluates once, then for every epoch and
        mini-batch alternates phase a (solve the coefficients of the batch) and phase O
        (update the statistics, then the dictionary). O is resampled every O_resample_step
        iterations and the metrics are evaluated every eval_step iterations. After training,
        a is recomputed for all samples through post_update and a final evaluation is logged.

        Returns:
        -----------
            self, so that calls can be chained
        """
        tic_all_start = time()
        # initialization
        if self.verbose:
            print("Initializing...")
        self.O, self.a = self._init_O_a(self.dict_size, self.A, method=self.O_init_method)
        if self.O_Q_ST_accurate:
            # the accurate statistics need the previous coefficients of every sample
            self.a_old = self.a.copy()
        self.Q, self.ST = self._init_QS(self.dict_size, self.feature_dim)

        self.global_iter_cnt = 0
        total_iter = self.epoch * self._max_iter
        # main loop
        # baseline evaluation before any update (also fills self.per_sample_regret)
        self.eval_and_log()
        if self.verbose:
            print("Training started...")
            self._display_stat()
            print("-"*27)
            print()
        tic_start = time()
        for e in range(self.epoch):
            for it, cur_idx in enumerate(self._sampler()):
                tic_iter_start = time()
                # phase a
                tic_iter_a_start = time()
                cur_a, cur_v = self.phase_a(cur_idx)
                self.log["time_compute_a"].append(time() - tic_iter_a_start)
                # phase O
                tic_iter_O_start = time()
                self.update_state_for_O(cur_a, cur_v, cur_idx)
                if DEBUG:
                    # time spent on the statistics update alone, reported below
                    time_O_update = time() - tic_iter_O_start
                # self.update_O(self.O, self.Q, self.ST, O_loop_cnt=self.O_loop_cnt)
                self.update_O_torch(self.O, self.Q, self.ST,
                                    O_loop_cnt=self.O_loop_cnt, device=self.device)
                self.log["time_compute_O"].append(time() - tic_iter_O_start)

                self.log["time_iter"].append(time() - tic_iter_start)

                if DEBUG:
                    print(f"*** Time a {self.log['time_compute_a'][-1]:.1f} | "
                          f"Time O {self.log['time_compute_O'][-1]:.1f} | "
                          f"Time O update {time_O_update:.1f}")

                self.global_iter_cnt += 1
                # resample_O itself returns immediately when the method is "no"
                if self.global_iter_cnt % self.O_resample_step == 0:
                    self.resample_O(self.O_resample_method)
                # evaluation and display
                if self.global_iter_cnt % self.eval_step == 0:
                    self.eval_and_log()
                    if self.verbose:
                        elapsed_time = time() - tic_start
                        # linear extrapolation from the average time per iteration so far
                        ETA = elapsed_time /\
                            self.global_iter_cnt * (total_iter - self.global_iter_cnt)
                        print(f"Epoch: {e} | Iter: {it} | "
                              f"Elapsed time: {elapsed_time:.1f} | ETA: {ETA:.1f}")
                        self._display_stat()

        # Post update for the parameters
        if self.verbose:
            print("-"*27)
            print()
            print("Started post processing...")
        self.post_update()
        self.eval_and_log()
        if self.verbose:
            print(f"Training finished in {(time()-tic_all_start):.1f} s.")
            self._display_stat()
        return self

    def post_update(self, O_is_sparsify=False, O_sparsify_conf=None):
        """
        Post processing for the weights.

        Optionally sparsifies O (recording the kept signal per row in the log), then
        recomputes the coefficients of every sample with one full pass of phase a.
        """
        if O_is_sparsify:
            # use percentile thresholding instead of the full version
            if self.verbose:
                print("Using O sparsify mode. Sparsifing O...")
            sparse_O, kept_signal = self.sparsify_O(O_sparsify_conf)
            self.O = sparse_O
            self.log["signal_list"] = kept_signal
        if self.verbose:
            print("Recompute a...")
        started = time()
        for batch_no, batch_idx in enumerate(self._sampler()):
            self.phase_a(batch_idx)
            if not self.verbose:
                continue
            elapsed = time() - started
            remaining = elapsed / (batch_no + 1) * (self._max_iter - batch_no - 1)
            print(f'{batch_no/self._max_iter:.3f} finished in {elapsed:.1f} s. '
                  f'ETA: {remaining:.1f} s',
                  end='\r', flush=True)
        if self.verbose:
            print()

    def sparsify_O(self, O_sparsify_conf):
        """
        sparsify O according to some approach
        """
        def get_remove_idx(O_row, O_sparsify_conf):
            cur_val = np.abs(O_row)
            cur_val = cur_val / np.sum(cur_val)
            sort_idx = np.argsort(-cur_val)
            cur_val = cur_val[sort_idx]
            i, nnz, signal =0, 0, 0
            while nnz < O_sparsify_conf["max_nonzero"] and\
                signal < O_sparsify_conf["signal_ratio"]:
                nnz += 1
                signal += cur_val[i]
                i += 1
            return sort_idx[i:], signal

        new_O = self.O.copy()
        c, _ = new_O.shape
        signal_list = []
        for i in range(c):
            to_remove, signal = get_remove_idx(new_O[i,:], O_sparsify_conf)
            new_O[i, to_remove] = 0
            signal_list.append(signal)
        return new_O, signal_list

    def phase_a_original(self, cur_idx):
        """
        Phase a of the Original problem: encode the selected rows of A against O, store the
        coefficients in self.a and return (coefficients, dense signals) of the batch.
        """
        batch_signal = self.A[cur_idx, :].toarray()
        solver_opts = {
            "method": self.a_method,
            "lam": self.lam,
            "n_a_nonzero": self.n_a_nonzero,
            "num_worker": self.num_worker,
        }
        batch_coef = self.compute_a(batch_signal, self.O, **solver_opts)
        self.a[cur_idx, :] = batch_coef
        return batch_coef, batch_signal

    def phase_a_weighted(self, cur_idx):
        """
        Phase a of the Weighted problem: encode the selected rows of AF against O F, store
        the coefficients in self.a and return them with the unweighted dense signals.
        """
        weighted_signal = self.AF[cur_idx, :]
        batch_signal = self.A[cur_idx, :].toarray()
        # the weighted dictionary O F is rebuilt on every call
        weighted_O = torch_batch_matrix_mul(self.O, self.F,
                                            batch_size=self.eval_batch_size,
                                            device=self.device)
        solver_opts = {
            "method": self.a_method,
            "lam": self.lam,
            "n_a_nonzero": self.n_a_nonzero,
            "num_worker": self.num_worker,
        }
        batch_coef = self.compute_a(weighted_signal, weighted_O, **solver_opts)
        self.a[cur_idx, :] = batch_coef
        return batch_coef, batch_signal

    def _update_state_for_O_accurate(self, cur_a, cur_v, cur_idx):
        cur_a_old = self.a_old[cur_idx, :]
        self.Q += cur_a.T @ cur_a - cur_a_old.T @ cur_a_old
        self.ST += (cur_a - cur_a_old).T @ cur_v
        self.a_old[cur_idx, :] = cur_a

    def _update_state_for_O_approximate(self, cur_a, cur_v):
        if self.global_iter_cnt < self.batch_size:
            theta = self.global_iter_cnt * self.batch_size
        else:
            theta = self.batch_size ** 2 + self.global_iter_cnt - self.batch_size
        beta = (theta + 1 - self.batch_size) / (theta + 1)

        self.Q = self.Q * beta + cur_a.T @ cur_a
        self.ST = self.ST * beta + cur_a.T @ cur_v

    def update_state_for_O(self, cur_a, cur_v, cur_idx):
        """
        aggregate import informtion for O phase
            The detailed algorithm is from Oneline Dictionary Learning for Sparse Coding
        """
        if self.O_Q_ST_accurate:
            self._update_state_for_O_accurate(cur_a, cur_v, cur_idx)
        else:
            self._update_state_for_O_approximate(cur_a, cur_v)

    def resample_O(self, method):
        if method == "no": # pylint: disable=no-else-return
            return
        elif method not in ["uniform", "greedy"]:
            raise ValueError("O resampling method can only be no, uniform, greedy."
                             f"Invalid O resampling method {method}.")
        if self.global_iter_cnt <= self.O_resample_warmup:
            self.log["num_resampled_O"].append(0)
            return
        if self.verbose:
            print("Resampling O...")
        # get the index to resample
        utility = np.sum(np.abs(self.a), axis=0)
        O_rows_to_replace = np.where(utility <= SparseDictionaryLearning.O_NONACTIVE_THRESHOLD)[0]
        # get the a rows to feed. Idea from
        # https://stackoverflow.com/questions/58070203/
        # find-top-k-largest-item-of-a-list-in-original-order-in-python
        if method == "greedy":
            A_rows_to_feed = heapq.nlargest(len(O_rows_to_replace),
                                            enumerate(self.per_sample_regret),
                                            key=itemgetter(1))
            A_rows_to_feed = [ii for (ii, val) in A_rows_to_feed]
        elif method == "uniform":
            A_rows_to_feed = np.random.choice(list(range(self.data_size)),
                                              size=len(O_rows_to_replace), replace=False)
        # excecute the reample
        for o_row, a_row in zip(O_rows_to_replace, A_rows_to_feed):
            self.O[o_row, :] = self.A[a_row, :].toarray() +\
                np.random.normal(0, SparseDictionaryLearning.NORMAL_VAR, (1, self.O.shape[1]))
            # roughly fill the a to prevent degenerating cases
            self.a[a_row, :] = 0
            self.a[a_row, o_row] = 1.
        # update inner state
        if self.O_Q_ST_accurate:
            self._update_state_for_O_accurate(self.a[A_rows_to_feed, :],
                                              self.A[A_rows_to_feed, :],
                                              A_rows_to_feed)

        self.log["num_resampled_O"].append(len(O_rows_to_replace))

    def save(self, res_folder):
        with open(os.path.join(res_folder, "dict.pkl"), 'wb') as fout:
            pickle.dump([self.O, self.a], fout)
        with open(os.path.join(res_folder, "log.pkl"), "wb") as fout:
            pickle.dump(self.log, fout)
        if self.verbose:
            print(f"Results are saved in {res_folder}.")

    def load(self, res_folder):
        if self.verbose:
            print(f"Loading results from {res_folder}.")
        with open(os.path.join(res_folder, "dict.pkl"), 'rb') as fin:
            [self.O, self.a] = pickle.load(fin)
        with open(os.path.join(res_folder, "log.pkl"), "rb") as fin:
            self.log = pickle.load(fin)
        cur_O_shape = self.O.shape
        if self.dict_size != cur_O_shape[0]:
            raise AttributeError("Dict size from loaded results does not match the configuration"
                                 f"{self.dict_size} != {cur_O_shape}")

    ###########
    ## Unused methods
    @staticmethod
    def update_O(O, Q, ST, O_loop_cnt=1):
        dict_size, _ = O.shape
        for _ in range(O_loop_cnt):
            for j in range(dict_size):
                tmp = (ST[j, :] - np.dot(Q[j, :], O)) / Q[j, j] + O[j, :]
                O[j, :] = tmp / max([1, np.linalg.norm(tmp)])

    @staticmethod
    def weighted_regret(O, a, A, F):
        return np.linalg.norm(np.dot(a, O).dot(F)-A.dot(F)) ** 2

    @staticmethod
    def rel_weighted_regret(O, a, A, F):
        res = []
        total, _ = a.shape
        appro = np.dot(a, O).dot(F)
        src = A.dot(F)
        for idx in range(total):
            self_norm = np.linalg.norm(src[idx, :]) ** 2
            diff_norm = np.linalg.norm(src[idx, :] - appro[idx, :]) ** 2
            res.append(diff_norm/self_norm)
        return np.mean(res)

    @staticmethod
    def regret(O, a, A):
        return np.linalg.norm(np.dot(a, O)-A) ** 2

    @staticmethod
    def rel_regret(O, a, A):
        res = []
        total, _ = a.shape

        appro = np.dot(a, O)
        for idx in range(total):
            self_norm = scipy.sparse.linalg.norm(A[idx, :]) ** 2
            diff_norm = np.linalg.norm(A[idx, :] - appro[idx, :]) ** 2
            res.append(diff_norm/self_norm)
        return np.mean(res)


class SMFRProxLinX(SparseDictionaryLearning):
    """
    SMFR (sparse multivariate factor regression) with proximal-linear updates.

    The model approximates Y by X @ A @ B and minimizes
        0.5 * |Y - X A B|_F^2 + lam1 * |A|_1 + lam2 * |B|_1 + lam3 * |A|_F^2
    by alternating extrapolated proximal-gradient steps on B and on A. The
    number of factors starts at ``num_factor_init`` and is decreased by one
    until both A and B have full rank (or a single factor is left).

    Inputs:
    - X: Matrix of predictors (n x p)
    - Y: Matrix of responses (n x q)
    - lam1: Regularization factor for ||A||_1
    - lam2: Regularization factor for ||B||_1
    - lam3: Regularization factor for ||A||^2_F
    - num_factor_init: Initial number of factors. The algorithm starts from
      this number and reduces it until A and B have full rank
    - adj, display_step: stored on the instance; not used by the methods
      defined in this class
    - verbose: print progress messages while fitting
    Outputs of ``fit`` (also stored as ``self.A``, ``self.B``,
    ``self.num_factors``):
    - A: Matrix deriving factors from inputs (p x m)
    - B: Matrix of regression coefficients from factors to outputs (m x q)
    - num_factors: Estimated number of factors
    """
    TOL = 1e-6        # relative objective change below which iterations stop
    MAX_ITER = 10000  # cap on iterations for one fixed number of factors

    def __init__(self,  # pylint: disable=super-init-not-called,too-many-arguments
                 X,
                 Y,
                 lam1,
                 lam2,
                 lam3,
                 num_factor_init,
                 adj=None,
                 display_step=20,
                 verbose=True):
        # NOTE: the parent initializer is not invoked; only the attributes
        # below are set on the instance.
        self.X = X
        self.Y = Y
        self.lam1 = lam1
        self.lam2 = lam2
        self.lam3 = lam3
        self.num_factor_init = num_factor_init
        self.adj = adj
        self.display_step = display_step
        self.verbose = verbose

        self.num_preds = X.shape[1]
        self.num_resps = Y.shape[1]

    def get_obj(self, X, Y, A, B):
        """
        Evaluate the SMFR objective
            0.5 * |Y - X A B|_F^2 + lam1 * |A|_1 + lam2 * |B|_1 + lam3 * |A|_F^2.

        The penalties match the proximal steps in ``fit``: ``lam1`` thresholds
        A, ``lam2`` thresholds B, and the gradient term ``2 * lam3 * A``
        corresponds to the squared Frobenius norm of A.
        """
        data_fit = 0.5 * np.linalg.norm(Y - X @ A @ B) ** 2
        l1_A = self.lam1 * np.sum(np.abs(A))
        l1_B = self.lam2 * np.sum(np.abs(B))
        ridge_A = self.lam3 * np.linalg.norm(A) ** 2
        return data_fit + l1_A + l1_B + ridge_A

    @staticmethod
    def _next_momentum(t_prev):
        """Return the next momentum counter and its extrapolation weight."""
        t_next = (1 + np.sqrt(1 + 4 * t_prev ** 2)) / 2
        return t_next, (t_prev - 1) / t_next

    def _fit_fixed_rank(self, num_factors, XtX, XtY, norm_XtX):
        """Run the alternating proximal updates for one fixed factor count.

        ``XtX``, ``XtY`` and ``norm_XtX`` are the precomputed ``X' X``,
        ``X' Y`` and ``||X' X||``. Returns the final ``(A, B)``.
        """
        X, Y = self.X, self.Y
        # Floor for the step-size denominators so an all-zero A or B cannot
        # cause a division by zero.
        floor = np.finfo(float).eps

        A = np.random.rand(self.num_preds, num_factors)
        B = np.random.rand(num_factors, self.num_resps)
        Aprev = A.copy()
        Bprev = B.copy()
        obj = self.get_obj(X, Y, A, B)
        t0 = 1
        for i in range(self.MAX_ITER):
            tic_iter = time()
            obj_start = obj  # objective before this full (B, A) sweep

            # --- updating B ---
            t0, wB = self._next_momentum(t0)
            AtXtXA = A.T @ XtX @ A
            AtXtY = A.T @ XtY
            LB = max(np.linalg.norm(AtXtXA), floor)
            Bhat = B + wB * (B - Bprev)
            Bnew = np_wthresh(Bhat - (AtXtXA @ Bhat - AtXtY) / LB, self.lam2 / LB)
            obj_new = self.get_obj(X, Y, A, Bnew)
            Bprev = B  # B is rebound below, never mutated in place
            if obj_new >= obj:
                # no descent: repeat the step without extrapolation and
                # refresh the objective so later comparisons are not stale
                B = np_wthresh(B - (AtXtXA @ B - AtXtY) / LB, self.lam2 / LB)
                obj = self.get_obj(X, Y, A, B)
            else:
                B = Bnew
                obj = obj_new

            # --- updating A ---
            t0, wA = self._next_momentum(t0)
            BBt = B @ B.T
            XtYBt = XtY @ B.T
            LA = max(norm_XtX * np.linalg.norm(BBt) + 2 * self.lam3, floor)
            Ahat = A + wA * (A - Aprev)
            Ghat = XtX @ Ahat @ BBt - XtYBt + 2 * self.lam3 * Ahat
            Anew = np_wthresh(Ahat - Ghat / LA, self.lam1 / LA)
            obj_new = self.get_obj(X, Y, Anew, B)
            Aprev = A
            if obj_new >= obj:
                # no descent: repeat the step without extrapolation
                Ghat = XtX @ A @ BBt - XtYBt + 2 * self.lam3 * A
                A = np_wthresh(A - Ghat / LA, self.lam1 / LA)
                obj = self.get_obj(X, Y, A, B)
            else:
                A = Anew
                obj = obj_new

            if self.verbose:
                print(f"Factor: {num_factors}, iter: {i}, obj:{obj:.4f} iter time: {time()-tic_iter:.1f}")

            # --- stopping check: relative change over the whole sweep ---
            diff = np.abs(obj - obj_start) / (obj_start + 1e-8)
            if diff < self.TOL:
                break
        return A, B

    def fit(self):
        """
        Fit A and B, shrinking the number of factors until both have full rank.

        Starting from ``num_factor_init`` factors, A and B are initialised at
        random and optimised; if either matrix is rank deficient the factor
        count is reduced by one and the optimisation restarts. The last
        solution is stored in ``self.A``, ``self.B`` and ``self.num_factors``.

        Returns:
            Tuple ``(A, B, num_factors)``.

        Raises:
            ValueError: if ``num_factor_init`` is smaller than 1.
        """
        if self.num_factor_init < 1:
            raise ValueError(
                f"num_factor_init must be at least 1, got {self.num_factor_init}")

        # Loop-invariant products, computed once for all factor counts.
        XtX = self.X.T @ self.X
        XtY = self.X.T @ self.Y
        norm_XtX = np.linalg.norm(XtX)

        for num_factors in range(self.num_factor_init, 0, -1):
            if self.verbose:
                print(f"Factor: {num_factors} started...")
            A, B = self._fit_fixed_rank(num_factors, XtX, XtY, norm_XtX)
            if (np.linalg.matrix_rank(B) == num_factors) and\
               (np.linalg.matrix_rank(A) == num_factors):
                break

        # Store the accepted solution (plain arrays, also when the loop broke).
        self.A = A
        self.B = B
        self.num_factors = num_factors
        return A, B, num_factors
